{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch, os\n",
    "from pytorch_pretrained_bert import OpenAIGPTTokenizer, OpenAIGPTLMHeadModel\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE\n",
    "from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME\n",
    "#from pytorch_pretrained_bert.tokenization import BertTokenizer\n",
    "from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear\n",
    "from bertviz.bertviz.pytorch_pretrained_bert import BertModel, BertTokenizer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Delete and Generate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Special tokens handed to the tokenizer; the model is sized for the same number of them\n",
    "special_tokens = [\n",
    "    '<POS>',\n",
    "    '<NEG>',\n",
    "    '<CON_START>',\n",
    "    '<START>',\n",
    "    '<END>',\n",
    "]\n",
    "# Use the GPU when one is available, otherwise fall back to the CPU\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt', special_tokens=special_tokens)\n",
    "model = OpenAIGPTLMHeadModel.from_pretrained('openai-gpt', num_special_tokens=len(special_tokens))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the saved weights into the model, move it to the device and switch to eval mode\n",
    "path = os.path.join(\n",
    "    os.getcwd(),\n",
    "    \"./log_yelp_bert_best_head_preprocessed/pytorch_model_zero_grad_1.bin\",  # Model Path\n",
    ")\n",
    "model_state_dict = torch.load(path, map_location=device)\n",
    "model.load_state_dict(model_state_dict)\n",
    "# .to() and .eval() both return the module, so the cell still displays the model\n",
    "model.to(device).eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# BERT classifier (2 labels) and its tokenizer, used to score the generated sentences\n",
    "bert_classifier_dir = \"./bert_classifier/\"\n",
    "tokenizer_cls = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)\n",
    "model_cls = BertForSequenceClassification.from_pretrained(bert_classifier_dir, num_labels=2)\n",
    "# .to() and .eval() both return the module, so the cell still displays the model\n",
    "model_cls.to(device).eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Softmax over the last dimension; applied to the classifier output in get_best_sentence\n",
    "sm = torch.nn.Softmax(dim=-1)\n",
    "# Length to which every classifier input is padded in get_best_sentence\n",
    "max_seq_len = 70"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display n_positions from the model config; the beam search uses it to bound decoding\n",
    "model.config.n_positions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def preditction_with_beam_search(ref_text, beam_width=3, vocab_length=40483):\n",
    "    \"\"\"\n",
    "    Decode one input sentence with beam search and return `beam_width` candidates.\n",
    "\n",
    "    Relies on the notebook-level `tokenizer`, `model` and `device`. At every step the\n",
    "    next-token probabilities of each beam are multiplied by that beam's running score\n",
    "    and the `beam_width` best (beam, token) pairs become the new beams. Decoding stops\n",
    "    once every beam contains <END>, or when the model input would exceed n_positions.\n",
    "\n",
    "    ref_text : string : Input sentence\n",
    "    beam_width : int : Width of the output beam (also the number of returned sentences)\n",
    "    vocab_length : int : Size of the Vocab after adding the special tokens; it is used to\n",
    "        unflatten the top-k indexes, so it must match the last dimension of the model output\n",
    "\n",
    "    Returns : list of strings : One decoded sentence per beam, cut at its first <END> token\n",
    "    \"\"\"\n",
    "    end_id = tokenizer.special_tokens[\"<END>\"]\n",
    "    sm = torch.nn.Softmax(dim=-1) # To calculate Softmax over the final layer Logits\n",
    "\n",
    "    tokens = tokenizer.tokenize(ref_text) # Tokenize the input text\n",
    "    indexed_tokens = tokenizer.convert_tokens_to_ids(tokens) # Convert tokens to ids\n",
    "    # Replicate the input ids once per beam\n",
    "    torch_tensor = torch.tensor([indexed_tokens for _ in range(beam_width)]).to(device)\n",
    "\n",
    "    beam_indexes = [[] for _ in range(beam_width)] # Token ids decoded so far, one list per beam\n",
    "    best_scoes = [0 for _ in range(beam_width)] # Running probability of each beam\n",
    "\n",
    "    # The model input is the prompt plus the decoded tokens and must stay within\n",
    "    # n_positions, so the prompt length limits how many steps can be taken.\n",
    "    max_steps = model.config.n_positions - len(indexed_tokens) + 1\n",
    "    stop_decode = False\n",
    "    count = 0\n",
    "    while count < max_steps and not stop_decode:\n",
    "        if count == 0: # First step: every row of the batch is the same prompt\n",
    "            with torch.no_grad():\n",
    "                # Output probability distribution over the Vocab for every position\n",
    "                preds = sm(model(torch_tensor))\n",
    "            top_v, top_i = preds[:, -1, :].topk(beam_width)\n",
    "            # Rows are identical here, so row 0 seeds all the beams\n",
    "            for i in range(beam_width):\n",
    "                beam_indexes[i].append(top_i[0][i].item())\n",
    "                best_scoes[i] = top_v[0][i].item()\n",
    "        else:\n",
    "            # Current state = original input followed by the tokens decoded for each beam\n",
    "            current_state = torch.cat((torch_tensor, torch.tensor(beam_indexes).to(device)), dim=1)\n",
    "            with torch.no_grad():\n",
    "                preds = sm(model(current_state))\n",
    "            # Scale the new probabilities by the score of their beam and flatten:\n",
    "            # total scores = beam_width * Vocab_Size\n",
    "            flatten_score = (preds[:, -1, :] * torch.tensor(best_scoes).to(device).unsqueeze(1)).view(-1)\n",
    "            vals, inx = flatten_score.topk(beam_width)\n",
    "            best_scoes = vals.tolist()\n",
    "            # Unflatten: the quotient is the source beam, the remainder is the token id.\n",
    "            # Floor division keeps the quotient an integer; true division would produce\n",
    "            # floats, which cannot be used as list indexes.\n",
    "            best_scoes_inx = (inx // vocab_length).tolist()\n",
    "            correct_inx = (inx % vocab_length).tolist()\n",
    "            # Build the new beams from the selected source beams\n",
    "            beam_indexes = [beam_indexes[src] + [tok] for src, tok in zip(best_scoes_inx, correct_inx)]\n",
    "        # Beams are reordered at every step, so completion is derived from the beams\n",
    "        # themselves rather than from per-slot flags.\n",
    "        stop_decode = all(end_id in beam for beam in beam_indexes)\n",
    "        count += 1\n",
    "\n",
    "    # Decode every beam up to (not including) its first <END> token\n",
    "    decoded_sentences = []\n",
    "    for beam in beam_indexes:\n",
    "        end_index = beam.index(end_id) if end_id in beam else len(beam)\n",
    "        decoded_sentences.append(tokenizer.decode(beam[:end_index]))\n",
    "\n",
    "    return decoded_sentences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_best_sentence(input_sentences, sentiment=1):\n",
    "    \"\"\"\n",
    "    Select, from the beam of sentences, the one to which the classifier assigns\n",
    "    the highest probability for the requested class.\n",
    "\n",
    "    Relies on the notebook-level `tokenizer_cls`, `model_cls`, `sm`, `device` and `max_seq_len`.\n",
    "\n",
    "    input_sentences : list of strings : Sentences generated by the Beam search decoding\n",
    "    sentiment: int : Expected sentiment (in general class for the classification)\n",
    "\n",
    "    Returns : string : The best-scoring sentence (the first one on ties)\n",
    "    Raises : ValueError : If input_sentences is empty\n",
    "    \"\"\"\n",
    "    if not input_sentences:\n",
    "        raise ValueError(\"input_sentences must contain at least one sentence\")\n",
    "\n",
    "    # BERT pre-processing\n",
    "    ids = []\n",
    "    segment_ids = []\n",
    "    input_masks = []\n",
    "    for sen in input_sentences:\n",
    "        # Keep room for [CLS] and [SEP]. Without truncation a long sentence would get no\n",
    "        # padding, the rows would differ in length and the tensors below could not be built.\n",
    "        text_tokens = tokenizer_cls.tokenize(sen)[:max_seq_len - 2]\n",
    "        tokens = [\"[CLS]\"] + text_tokens + [\"[SEP]\"]\n",
    "        temp_ids = tokenizer_cls.convert_tokens_to_ids(tokens)\n",
    "        padding = [0] * (max_seq_len - len(temp_ids))\n",
    "\n",
    "        ids.append(temp_ids + padding)\n",
    "        input_masks.append([1] * len(temp_ids) + padding) # 1 for real tokens, 0 for padding\n",
    "        segment_ids.append([0] * max_seq_len) # Single-segment input\n",
    "\n",
    "    ids = torch.tensor(ids).to(device)\n",
    "    segment_ids = torch.tensor(segment_ids).to(device)\n",
    "    input_masks = torch.tensor(input_masks).to(device)\n",
    "    # prediction\n",
    "    with torch.no_grad():\n",
    "        preds = sm(model_cls(ids, segment_ids, input_masks))\n",
    "\n",
    "    scores = [row[sentiment] for row in preds.tolist()]\n",
    "    # max() returns the first maximum and always yields a valid index\n",
    "    best = max(range(len(scores)), key=scores.__getitem__)\n",
    "    return input_sentences[best]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode one example sentence with four beams and display the candidates\n",
    "op = preditction_with_beam_search(\n",
    "    \"<POS> <CON_START> it is not terrible , but it is very good . <START>\",\n",
    "    beam_width=4,\n",
    ")\n",
    "op"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick the candidate with the highest classifier probability for class 1\n",
    "get_best_sentence(op, sentiment=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Predictions for the reference files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory holding the reference files; the prediction files are written here as well.\n",
    "# NOTE(review): absolute, machine-specific path - adjust before running elsewhere.\n",
    "output_dir = \"/home/ubuntu/bhargav/data/yelp/processed_files_with_bert_with_best_head/\"\n",
    "os.listdir(output_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every line of reference_0.txt: decode with beam search, keep the candidate that the\n",
    "# classifier scores highest for class 1, log it and write it to the predictions file.\n",
    "with open(os.path.join(output_dir,\"reference_0_predictions_with_beam_search.txt\") ,\"w\", encoding='utf-8') as out_fp:\n",
    "    # Read with the same encoding that is used for writing\n",
    "    with open(os.path.join(output_dir, \"./reference_0.txt\"), encoding='utf-8') as fp:\n",
    "        for c, line in enumerate(fp):\n",
    "            out_sen = preditction_with_beam_search(line.strip(), beam_width=5, vocab_length=max(tokenizer.special_tokens.values()) + 1)\n",
    "            # Run the classifier once and reuse the result for both the log and the file\n",
    "            best_sen = get_best_sentence(out_sen, sentiment = 1)\n",
    "            print(c, best_sen)\n",
    "            out_fp.write(best_sen + \"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every line of reference_1.txt: decode with beam search, keep the candidate that the\n",
    "# classifier scores highest for class 0, log it and write it to the predictions file.\n",
    "with open(os.path.join(output_dir,\"reference_1_predictions_with_beam_search.txt\") ,\"w\", encoding='utf-8') as out_fp:\n",
    "    # Read with the same encoding that is used for writing\n",
    "    with open(os.path.join(output_dir, \"./reference_1.txt\"), encoding='utf-8') as fp:\n",
    "        for c, line in enumerate(fp):\n",
    "            # vocab_length must be passed by keyword: a positional argument after\n",
    "            # beam_width=5 is a SyntaxError\n",
    "            out_sen = preditction_with_beam_search(line.strip(), beam_width=5, vocab_length=max(tokenizer.special_tokens.values()) + 1)\n",
    "            # Run the classifier once and reuse the result for both the log and the file\n",
    "            best_sen = get_best_sentence(out_sen, sentiment = 0)\n",
    "            print(c, best_sen)\n",
    "            out_fp.write(best_sen + \"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Delete, Retrieve and Generate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Special tokens handed to the tokenizer; the model is sized for the same number of them\n",
    "special_tokens = [\n",
    "    '<ATTR_WORDS>',\n",
    "    '<CON_START>',\n",
    "    '<START>',\n",
    "    '<END>',\n",
    "]\n",
    "# Use the GPU when one is available, otherwise fall back to the CPU\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt', special_tokens=special_tokens)\n",
    "model = OpenAIGPTLMHeadModel.from_pretrained('openai-gpt', num_special_tokens=len(special_tokens))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the saved weights into the model, move it to the device and switch to eval mode\n",
    "path = os.path.join(\n",
    "    os.getcwd(),\n",
    "    \"./log_yelp_retireve_edit_bert_best_head_preprocessing/pytorch_model_zero_grad_1.bin\",\n",
    ")\n",
    "model_state_dict = torch.load(path, map_location=device)\n",
    "model.load_state_dict(model_state_dict)\n",
    "# .to() and .eval() both return the module, so the cell still displays the model\n",
    "model.to(device).eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output dir have reference files generated using TFIDF for retrieve attributes from opposite corpus\n",
    "output_dir = \"/home/ubuntu/bhargav/data/yelp/processed_files_with_bert_with_best_head/delete_retrieve_edit_model/tfidf/\"\n",
    "os.listdir(output_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every line of reference_0_content.txt: decode with beam search, keep the candidate that\n",
    "# the classifier scores highest for class 1, log it and write it to the predictions file.\n",
    "with open(os.path.join(output_dir,\"reference_0_predictions_with_full_sentence_match_beam_search_bm5.txt\") ,\"w\", encoding='utf-8') as out_fp:\n",
    "    # Read with the same encoding that is used for writing\n",
    "    with open(os.path.join(output_dir, \"./reference_0_content.txt\"), encoding='utf-8') as fp:\n",
    "        for c, line in enumerate(fp):\n",
    "            out_sen = preditction_with_beam_search(line.strip(), beam_width=5, vocab_length=max(tokenizer.special_tokens.values()) + 1)\n",
    "            # Run the classifier once and reuse the result for both the log and the file\n",
    "            best_sen = get_best_sentence(out_sen, sentiment = 1)\n",
    "            print(c, best_sen)\n",
    "            out_fp.write(best_sen + \"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every line of reference_1_content.txt: decode with beam search, keep the candidate that\n",
    "# the classifier scores highest for class 0, log it and write it to the predictions file.\n",
    "with open(os.path.join(output_dir,\"reference_1_predictions_with_full_sentence_match_beam_search_bm5.txt\") ,\"w\", encoding='utf-8') as out_fp:\n",
    "    # Read with the same encoding that is used for writing\n",
    "    with open(os.path.join(output_dir, \"./reference_1_content.txt\"), encoding='utf-8') as fp:\n",
    "        for c, line in enumerate(fp):\n",
    "            out_sen = preditction_with_beam_search(line.strip(), beam_width=5, vocab_length=max(tokenizer.special_tokens.values()) + 1)\n",
    "            # Run the classifier once and reuse the result for both the log and the file\n",
    "            best_sen = get_best_sentence(out_sen, sentiment = 0)\n",
    "            print(c, best_sen)\n",
    "            out_fp.write(best_sen + \"\\n\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
